Add U8 copy operation for K16 MMA #374
Conversation
# Conflicts:
#   include/cute/arch/xe_copy_1B.hpp
#   include/cute/arch/xe_copy_2B.hpp
#   include/cute/arch/xe_copy_4B.hpp

# Conflicts:
#   include/cute/arch/mma_xe.hpp
With FP8xFP8 GEMM, this config didn't work, although the corresponding code works for FP16xFP16 GEMM:

using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;
using TileShape = Shape<_64, _256, _32>;
using TiledMma =
    typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
                            Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;

The compile-time error was:

It seems to be a bug, since the shapes are correct. Thanks!
…ked-copy

# Conflicts:
#   CMakeLists.txt
#   include/cute/arch/copy_xe_U16.hpp
#   include/cute/arch/copy_xe_U32.hpp
#   include/cute/arch/copy_xe_U4.hpp
#   include/cute/arch/copy_xe_U64.hpp
#   include/cute/arch/copy_xe_U8.hpp
#   include/cute/arch/copy_xe_builtin.hpp
#   include/cute/arch/copy_xe_spirv.hpp
#   include/cutlass/epilogue/collective/xe_epilogue.hpp
struct XE_2D_U8x32x32_LD_N {
  using BlockShape = Shape<_32, _32>;

  template <class T>
  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
                                    int height, int pitch, intel::coord_t coord,
                                    T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
    static_assert(sizeof(T) == 1, "Expected T to have size 1");
    // detail::XeSubgroup2DBlockLoad<1, 16, 32, 2>{}(baseoffset, width, height, pitch, coord, dst);
    // Use the transform (VNNI) version as it provides better performance when loading the A matrix for
    // GEMM FP8 and GEMM mixed-precision types.
Hi @aacostadiaz,
Please help resolve a couple of doubts.
The `DstLayout` in the atom traits for this copy atom is `Layout<Shape <_16,Shape <_8, _2, _32>>, Stride<_16,Stride< _1,_128,_256>>>`, which seems to correspond to a plain layout. Does this mean that when the data is copied from global memory it is first transformed into the VNNI layout before being written to the registers, and later converted to `DstLayout`? If yes, can you please point out where/how this is handled in the code?
Also, I don't see any `shfl`-based instructions in the generated assembly dump. Is it possible that the shuffle (for the VNNI -> plain layout conversion) is not happening directly via lane registers -> lane registers (I understand this isn't possible on Nvidia GPUs, but it is somehow possible on Intel GPUs, based on the documentation), but rather via lane registers -> shared local memory -> lane registers?
Thanks!
The Copy trait is used to describe how a copy operation works so that the rest of the code can understand it; it does not change how the actual copy operation behaves.
In this case, for the VNNI copies, the transformation happens inside the builtin/SPIR-V function. There is no transformation inside CUTLASS for that: we just call these builtin/SPIR-V functions, and the copy traits describe how they work.
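For intuition, here is a minimal sketch of the kind of metadata such a trait carries. The struct name and `ThrID` below are assumptions for illustration only (this is not the repository's actual `Copy_Traits` specialization); the `SrcLayout`/`DstLayout` values are the ones quoted later in this review.

```cpp
#include <cute/tensor.hpp>

using namespace cute;

// Hypothetical stand-in for a copy trait: it only records the thread-ID map and the
// (thr,val)->bit maps that describe what the hardware builtin already does; it
// performs no data movement or transformation itself.
struct ExampleU8x32x32TraitSketch {
  using ThrID     = Layout<_16>;   // assumed: one 16-work-item subgroup
  // Map from (src-thr, src-val) to bit (source block in global memory, thread stride 0)
  using SrcLayout = Layout<Shape <_16, Shape <_8, _2, _32>>,
                           Stride< _0, Stride< _1, _128, _256>>>;
  // Map from (dst-thr, dst-val) to bit (values landing in each work-item's registers)
  using DstLayout = Layout<Shape <_16, Shape <_8, _2, _32>>,
                           Stride<_16, Stride< _1, _128, _256>>>;
};

int main() {
  // Inspect the described destination layout; nothing here touches the GPU.
  print(ExampleU8x32x32TraitSketch::DstLayout{});
  print("\n");
}
```

The key point is that these layouts are pure descriptions consumed by the rest of CuTe; the VNNI packing itself is done by the hardware builtin.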
@aacostadiaz, thanks, but I meant that since `A` for `FP8` GEMM is being loaded in the VNNI layout in this PR, and the GEMM output is correct, that seems to suggest the layout must've been changed from VNNI to plain somewhere in the code.

> In this case, for the VNNI copies, the transformation happens inside the builtin/SPIR-V function.

Sorry, do you mean the VNNI -> plain transformation also happens inside the builtin? Thanks!
Yes. `XeSubgroup2DBlockLoad<1, 16, 32, 2>` and `XeSubgroup2DBlockLoadTransform<1, 16, 32, 2>` (Transform is the VNNI transformation) load the exact same data, and we end up with the exact same values in the registers. The only difference with `XeSubgroup2DBlockLoadTransform<1, 16, 32, 2>` is that the packing is 32 bits, so we get 32-bit elements out of the copy operation. If you recast these into four 8-bit elements, you have exactly the same information as with the `XeSubgroup2DBlockLoad<1, 16, 32, 2>` copy.
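To see that last point concretely, here is a small host-side sketch (plain C++, not the device builtin; the byte values are made up) showing that packing four consecutive 8-bit values into one 32-bit element and recasting back to bytes preserves the information exactly:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Four u8 values one work-item would receive from the plain 8-bit load.
  const uint8_t plain[4] = {0x11, 0x22, 0x33, 0x44};

  // What the packed (transform) load hands back: the same four bytes as one 32-bit element.
  uint32_t packed;
  std::memcpy(&packed, plain, sizeof(packed));

  // Recast the 32-bit element into four 8-bit elements: identical information.
  uint8_t recast[4];
  std::memcpy(recast, &packed, sizeof(recast));

  for (int i = 0; i < 4; ++i) {
    std::printf("plain[%d] = 0x%02x, recast[%d] = 0x%02x\n", i, plain[i], i, recast[i]);
  }
  return 0;
}
```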
Ah, since we load the data column-wise (from the POV of one work-item) with `XeSubgroup2DBlockLoad` anyway, it doesn't matter whether we use `XeSubgroup2DBlockLoadTransform` or `XeSubgroup2DBlockLoad` (I haven't yet reasoned about whether it'd work for all relevant tile shapes, though; I'll do that later).
From https://github.khronos.org/SPIRV-Registry/extensions/INTEL/SPV_INTEL_2d_block_io.html:
[figure from the SPV_INTEL_2d_block_io extension specification]
Hi @aacostadiaz, the vLLM team is blocked by this issue. Could you please prioritize this and merge it into the main branch?
Unsure about the Layout for the new operation, which looks like it might relate to @sanchitintel's comment.
Aside from that, just a nit suggestion.
include/cute/atom/copy_traits_xe.hpp
Outdated
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _32>>,
                         Stride< _0,Stride< _1,_128,_256>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16,Shape <_8, _2, _32>>,
                         Stride<_16,Stride< _1,_128,_256>>>;
It looks like `XE_2D_Packed_U8x32x32_LD_N` and `XE_2D_U8x32x32_LD_N` have the same `*Layout` traits. Is that expected?
I'll check out the `copy_debug` tool to verify why they look similar (they were the same when you commented) and will report back with any findings. Thanks!
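One way to compare the two atoms' traits directly is sketched below. This is only a sketch under assumptions: the header path and instantiating `Copy_Traits` with an empty trailing argument pack are not verified here.

```cpp
#include <cstdio>
#include <type_traits>
#include <cute/tensor.hpp>
#include <cute/atom/copy_traits_xe.hpp>  // assumed header for the Xe copy traits

using namespace cute;

int main() {
  // Pull the DstLayout out of each atom's Copy_Traits specialization and check
  // whether they are literally the same type.
  using PlainDst  = typename Copy_Traits<XE_2D_U8x32x32_LD_N>::DstLayout;
  using PackedDst = typename Copy_Traits<XE_2D_Packed_U8x32x32_LD_N>::DstLayout;

  print(PlainDst{});  print("\n");   // static layouts are default-constructible
  print(PackedDst{}); print("\n");
  std::printf("identical DstLayout: %s\n",
              std::is_same_v<PlainDst, PackedDst> ? "yes" : "no");
}
```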
Co-authored-by: Joe Todd <[email protected]>
Co-authored-by: Tadej Ciglarič <[email protected]>
@@ -535,7 +535,7 @@ int main(int argc, const char** argv)
  using ElementScale = MmaType;

  // Note: XE_2D_U18x32x32_LD_N is incompatible with our bf16 MMA atoms
nit: this comment seems to be obsolete now.
# Conflicts:
#   include/cute/atom/copy_traits_xe.hpp
LGTM, Thanks.
This PR adds a U8 copy operation that works correctly with the K16 MMA for FP8 GEMM and mixed-dtype GEMM.
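For reference, below is a sketch of how a kernel configuration might pick up the new operation, reusing the failing config quoted at the top of this conversation. Using `XE_2D_Packed_U8x32x32_LD_N` for the A-matrix load while leaving everything else unchanged is an assumption for illustration, not taken verbatim from the PR's examples.

```cpp
// Sketch only: swap the A-matrix copy in the earlier config for the new packed U8 op.
using GmemTiledCopyA = XE_2D_Packed_U8x32x32_LD_N;  // new U8 copy op from this PR (assumed usage for A)
using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;         // unchanged from the config quoted above
using TileShape      = Shape<_64, _256, _32>;
using TiledMma =
    typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
                            Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
```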